#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @Time   : 2021/9/7 23:16
# @Author : cjw
from typing import Optional
from sqlalchemy.orm import Session
from sqlalchemy import select, update, delete, func

from app.models.case import Case
from app.schemas.case import CaseCreate, Case as SCase


def get_case(db: Session, c_id: int):
	"""
	Look up a single case by its primary key.

	:param db: database session used for the lookup
	:param c_id: primary-key id of the case
	:return: whatever ``db.get(Case, c_id)`` yields; for a SQLAlchemy
		``Session`` that is the ``Case`` instance, or ``None`` if no row
		has that id
	"""
	found = db.get(Case, c_id)
	return found


def get_case_count(db: Session):
	"""
	Count the rows in the case table.

	:param db: database session used to run the query
	:return: the single result row as a mapping, with the total stored
		under the key ``count``
	"""
	total = func.count(Case.id).label('count')
	rows = db.execute(select(total)).mappings()
	return rows.one()


def get_case_by_node_id(db: Session, node_id: str):
	"""
	Look up a case by the node id it is executed under.

	:param db: database session used to run the query
	:param node_id: node id of the case to find
	:return: the matching ``Case``, or ``None`` when nothing matches;
		``scalar_one_or_none`` raises if more than one row matches
	"""
	query = select(Case).where(Case.node_id == node_id)
	outcome = db.execute(query)
	return outcome.scalar_one_or_none()


def get_cases(db: Session, skip: int = 0, limit: int = 10) -> list:
	"""
	Fetch one page of cases, ordered by id.

	:param db: database session used to run the query
	:param skip: number of rows to skip (SQL OFFSET)
	:param limit: maximum number of rows to return (SQL LIMIT)
	:return: list with the ``as_dict()`` result of every case on the page
	"""
	# OFFSET/LIMIT without ORDER BY leaves the row order up to the database,
	# so consecutive pages could repeat or miss rows. Ordering by the id
	# column makes the pagination deterministic.
	stmt = select(Case).order_by(Case.id).offset(skip).limit(limit)
	return [row.as_dict() for row in db.execute(stmt).scalars()]


def create_case(db: Session, case: CaseCreate):
	"""
	Insert a new case built from the fields of ``case`` and commit it.

	:param db: database session used for the insert
	:param case: schema object whose ``dict()`` supplies the column values
	:return: the persisted ``Case`` instance, reloaded from the database
	"""
	db_case = Case(**case.dict())
	db.add(db_case)
	db.commit()
	# A commit expires the instance's attributes by default. Reload them now
	# so the returned object already carries database-generated values and
	# stays readable even if the caller closes the session before using it.
	db.refresh(db_case)
	return db_case


def update_case(db: Session, case: SCase):
	"""
	Overwrite the stored case whose id equals ``case.id`` and commit.

	:param db: database session used for the update
	:param case: schema object; ``case.id`` selects the row and
		``case.dict()`` supplies the values written to it
	:return: ``rowcount`` reported for the UPDATE statement
	"""
	statement = (
		update(Case)
		.where(Case.id == case.id)
		.values(**case.dict())
		.execution_options(synchronize_session='fetch')
	)
	outcome = db.execute(statement)
	db.commit()
	return outcome.rowcount


def delete_case(db: Session, c_id: Optional[int] = None, node_id: Optional[str] = None):
	"""
	Delete cases selected by id or, failing that, by node id, then commit.

	``c_id`` takes precedence when both identifiers are supplied. When
	neither is supplied nothing is executed and 0 is returned.

	:param db: database session used for the delete
	:param c_id: id of the case to delete
	:param node_id: node id of the case(s) to delete; only used when
		``c_id`` is None
	:return: ``rowcount`` reported for the DELETE statement, or 0 when no
		identifier was given
	"""
	if c_id is not None:
		# Explicit None check so that an id of 0 is still treated as an id.
		condition = Case.id == c_id
	elif node_id is not None:
		condition = Case.node_id == node_id
	else:
		# Comparing node_id with None would render as ``node_id IS NULL`` and
		# delete every case that has no node id, so bail out instead.
		return 0
	result = db.execute(delete(Case).where(condition))
	db.commit()
	return result.rowcount
